%% Tract-specific analysis of ALS spinal cord data
%  Regents of the University of Minnesota.
% This software is licensed under The MIT License.  
% How to cite: If you use this software for your work, please cite the following paper published in Communications Biology: 
% Pisharady, P.K., Eberly, L.E., Cheong, I. et al. Tract-specific analysis improves sensitivity of spinal cord diffusion MRI 
% to cross-sectional and longitudinal changes in amyotrophic lateral sclerosis. Commun Biol 3, 370 (2020). https://doi.org/10.1038/s42003-020-1093-z

% To run this code, download the along-tract-stats toolbox from
% https://github.com/johncolby/along-tract-stats and add it to the MATLAB path
% (see Colby et al., NeuroImage, 2012, cited in the main paper).

clear all
close all

% Voxels whose atlas value is below this threshold are excluded from the tract mask.
atlas_thr = 0.1;

%-- Atlas labels to analyse. Keep exactly one assignment active.
% Active selection: labels '00'..'29' (white matter).
atlas_to_use = arrayfun(@(k) sprintf('%02d', k), 0:29, 'UniformOutput', false);
% Alternatives (uncomment one and comment out the active line above):
%atlas_to_use = arrayfun(@(k) sprintf('%02d', k), 0:36, 'UniformOutput', false); % whole cord ('00'..'36')
%atlas_to_use={'04','05','08','09','10','11','16','17','18','19','20','21','22','23','24','25','26','27'};%Descending tracts
%atlas_to_use={'04','05'};% LCST
%atlas_to_use={'08','09'};% RST
%atlas_to_use={'12','13'}; % SL
%atlas_to_use={'22','23'}; % ACST
%atlas_to_use={'00','01','02','03'};% PC
%atlas_to_use={'30','31','32','33','34','35'}; % gray matter
%atlas_to_use={'00','01','02','03','06','07','12','13'}; % Ascending tracts

%-- Manual per-subject correction used to match the C2-C6 levels of all
%   subjects: one entry per subject, shifted by a constant offset of 6.
lslice = 6 + [0 2 0 -3 0 3 0 5 0 1 1 2 1 2 2 -4 -2 0 2 0 1 2 0 1 1 2 1 1 1 1 0 0 2 -1 1 2 2 1 0 1];

feature_to_use='_FA.nii';
exDir='/Users/pramod/CMRR/Data/ALS';% Data directory
dataDir  = fullfile(exDir, 'pcv_spine_fsl'); % this is where the .trk and .nii files reside
atlasDir = fullfile(exDir, 'ALS_sct'); % this is where the SCT based spinal cord segmentations reside

%-- Read the demographics file: three whitespace-separated string fields per
%   line. C{1} is used as the subject ID and C{3} as the group label (labels
%   starting with 'c' are treated as controls in the per-subject loop).
demoPath = fullfile(exDir, 'between_group_spine/Demographics_pvc.txt'); % Demographics file
[fileID, errmsg] = fopen(demoPath);
if fileID < 0
    error('Cannot open demographics file %s: %s', demoPath, errmsg);
end
C = textscan(fileID,'%s %s %s');
fclose(fileID);
len_C=length(C{1});

% The per-subject loop indexes lslice(m) for every subject, so it needs one
% correction entry per subject listed in the demographics file.
if numel(lslice) < len_C
    error('lslice has %d entries but the demographics file lists %d subjects.', numel(lslice), len_C);
end

con_FA=[];
pat_FA=[];

%-- To save the extracted features: one row per subject (trk_props.txt) and
%   one row per subject and profile point (trk_data.txt).
trkPropsPath = fullfile(exDir, 'between_group_spine/trk_props.txt');
[fid1, errmsg] = fopen(trkPropsPath, 'wt');
if fid1 < 0
    error('Cannot open %s for writing: %s', trkPropsPath, errmsg);
end
fprintf(fid1, 'ID\tHemisphere\tTract\tStreamlines');

trkDataPath = fullfile(exDir, 'between_group_spine/trk_data.txt');
[fid2, errmsg] = fopen(trkDataPath, 'wt');
if fid2 < 0
    fclose(fid1);
    error('Cannot open %s for writing: %s', trkDataPath, errmsg);
end
fprintf(fid2, 'ID\tPoint\tHemisphere\tTract\tFA\tSD');

%-- Per-subject processing: mask the feature volume with the atlas, filter the
%   streamlines, sample the feature along them, build a fixed-length profile,
%   plot it, append it to its group and write it to the output text files.
for m=1:len_C
    fprintf('Subject %d/%d: %s\n', m, len_C, C{1}{m}); % progress report

    %-- Per-subject input files
    trkPath   = fullfile(dataDir, strcat('tract_',C{1}{m},'_spline.trk'));     % tractography files
    volPath   = fullfile(dataDir, strcat(C{1}{m},feature_to_use));            % FA etc. data files
    atlasPath = fullfile(atlasDir, strcat('/',C{1}{m},'/label/atlas/'));       % Atlas based segmentations
    maskPath  = fullfile(atlasDir, strcat('/',C{1}{m},'/mask_dwi_mean.nii'));  % spinal cord masks

    [header, tracks] = trk_read(trkPath);
    volume = read_avw(volPath);
    atlas  = read_atlas(atlas_to_use,atlasPath);
    mask   = read_avw(maskPath);

    %-- Place the atlas into a volume-sized array, anchored at the lowest
    %   x/y/z voxel indices covered by the cord mask.
    [maskx,masky,maskz] = ind2sub(size(mask), find(mask));
    x0 = min(maskx); y0 = min(masky); z0 = min(maskz);
    atlas_big = zeros(size(volume));
    atlas_big(x0:(x0+size(atlas,1)-1), y0:(y0+size(atlas,2)-1), z0:(z0+size(atlas,3)-1)) = atlas;
    atlas_big(atlas_big<atlas_thr) = 0;

    %-- Zero every voxel of the feature volume whose atlas value is not
    %   >= atlas_thr. atlas_big is cropped to the volume's extent first because
    %   the assignment above can grow it if the atlas extends past the volume.
    in_atlas = atlas_big(1:size(volume,1), 1:size(volume,2), 1:size(volume,3)) >= atlas_thr;
    volume(~in_atlas) = 0;

    no_fib(m) = length(tracks); %#ok<SAGROW>

    %-- Keep streamlines with more than 40 points (discard partial tracts) and
    %   reflect their y coordinates about header.dim(2)*header.voxel_size(2).
    % NOTE(review): presumably this converts between .trk and image y
    % orientation - confirm.
    y_extent = header.dim(2)*header.voxel_size(2);
    tracks_thr = [];
    kk = 0;
    for i = 1:length(tracks)
        if tracks(i).nPoints > 40
            kk = kk + 1;
            pts = tracks(i).matrix;
            pts(:,2) = y_extent - pts(:,2);
            tracks_thr(kk).nPoints = tracks(i).nPoints; %#ok<SAGROW>
            tracks_thr(kk).matrix  = pts;
        end
    end
    if isempty(tracks_thr)
        error('Subject %s has no streamlines longer than 40 points.', C{1}{m});
    end
    tracks = tracks_thr;
    header.n_count = size(tracks,2);
    streamlines(m) = size(tracks,2); %#ok<SAGROW>

    %-- Sample the (masked) feature volume along the streamlines.
    % NOTE(review): the meaning of the trk_flip argument [68 50 4] and of the
    % 'FA' label (used whatever feature_to_use is) is not visible here - confirm.
    tracks_interp = trk_flip(header, tracks, [68 50 4]);
    tracks_interp_str = trk_restruc_new(tracks_interp);
    [header_sc, tracks_sc] = trk_add_sc_new(header, tracks_interp_str, volume, 'FA');
    [scalar_mean, scalar_sd] = trk_mean_sc_new(header_sc, tracks_sc);

    %-- Drop points whose mean is zero (outside the atlas mask). The SD is
    %   indexed with the same mask so that both profiles stay aligned.
    % NOTE(review): assumes trk_mean_sc_new returns scalar_mean and scalar_sd
    % with the same number of elements - confirm.
    scalar_mean = scalar_mean(:);
    scalar_sd   = scalar_sd(:);
    in_mask     = scalar_mean ~= 0;
    scalar_mean = scalar_mean(in_mask);   % same values/order as nonzeros(scalar_mean)
    scalar_sd   = scalar_sd(in_mask);

    %-- Crop the start of the profile by the per-subject level correction
    %   (lslice(m) in units of 1/28 of the profile length), resample to
    %   n_profile_pts points and reverse the order. The SD gets exactly the same
    %   treatment so that each written SD belongs to the mean at the same point.
    n_profile_pts = 75;
    c7_thr      = max(1, round(lslice(m)*(length(scalar_mean)/28)));
    scalar_mean = fliplr(timewarp(scalar_mean(c7_thr:end), n_profile_pts));
    scalar_sd   = fliplr(timewarp(scalar_sd(c7_thr:end),   n_profile_pts));

    %-- Plot this subject's profile and append it (as a column) to its group.
    if (C{3}{m}(1)=='c')        % control
        line_color = [0.3,.6,1];
        con_FA = [con_FA scalar_mean']; %#ok<AGROW>
    else                        % every other group label is treated as a patient
        line_color = [.9,.1,.3];
        pat_FA = [pat_FA scalar_mean']; %#ok<AGROW>
    end
    plot(scalar_mean, 'Color',line_color,'LineWidth',.1);
    hold on
    ylim([.2 .8]) % for FA
    %ylim([.0005 .0011]) % for RD etc.
    title(C{1}{m})
    xlabel('Position along tract (%)','FontSize',14,'Color','k');
    ylabel('FA','FontSize',14,'Color','k');

    %-- Write per-subject results. The Hemisphere/Tract columns are always the
    %   fixed labels 'L' and 'AF'.
    fprintf(fid1, '\n%s\t%s\t%s\t%d', C{1}{m}, 'L','AF', header.n_count);
    for iPt=1:length(scalar_mean)
        fprintf(fid2, '\n%s\t%d\t%s\t%s\t%0.4f\t%0.4f', C{1}{m}, iPt, 'L','AF', scalar_mean(iPt), scalar_sd(iPt));
    end
end
% Flush and release the per-subject output files.
fclose(fid1);
fclose(fid2);
   
%-- Combine the groups (control columns first, then patient columns) and save
%   them for the statistical analysis.
FA_all=[con_FA pat_FA];
diffusionFiles{1}=FA_all;
save('diffusiondata_sct','diffusionFiles');% The data saved here is used for further statistical analysis

%-- Group mean and SD at each profile point (rows of con_FA/pat_FA are points,
%   columns are subjects). The reduction dimension is given explicitly so that
%   a group with a single subject still yields a per-point profile; mean(X')
%   of a single column would collapse to one scalar.
meanconFA = mean(con_FA, 2)';
stdconFA  = std(con_FA, 0, 2)';
meanpatFA = mean(pat_FA, 2)';
stdpatFA  = std(pat_FA, 0, 2)';

%-- Figure: group mean profiles
figure
plot(meanconFA,'Color',[0.3,.6,1],'LineWidth',4);
hold on
plot(meanpatFA,'Color',[.9,.1,.3],'LineWidth',4);
legend('Control','ALS');
set(gca,'Xcolor',[0.5 0.5 0.5]);
set(gca,'Ycolor',[0.5 0.5 0.5]);
grid on
xlabel('Position along tract (C2-C6)','FontSize',14,'Color','k');
ylabel('mean FA','FontSize',14,'Color','k'); % RD (mm^2/s)
ylim([.3 .6])
%ylim([.0008 .0011])

%-- Figure: per-point distributions across subjects, every third point
con_FA_T=con_FA';
pat_FA_T=pat_FA';

figure
boxplot(con_FA_T(:,1:3:end),'Color',[0.3,.6,1],'PlotStyle','compact','BoxStyle','outline','Widths',.3);
hold on
boxplot(pat_FA_T(:,1:3:end),'Color',[.9,.1,.3],'PlotStyle','compact','BoxStyle','outline','Widths',.3);
ylim([.2 .8])
legend('Control','ALS');
grid on
xlabel('Position along tract (C2-C6)','FontSize',14,'Color','k');
ylabel('mean FA','FontSize',14,'Color','k');
title('Only white matter','FontSize',16,'Color','k');

%-- Figure: group means with one-sided SD bars at every third point
%   (errorbar(x,y,neg,pos)). For FA, controls get upper bars and patients get
%   lower bars; the commented RD variants swap the sides.
con_idx = 1:3:numel(meanconFA);
pat_idx = 1:3:numel(meanpatFA);
figure
errorbar(1:numel(con_idx),meanconFA(con_idx),zeros(1,numel(con_idx)),stdconFA(con_idx),'Color',[0.3,.6,1],'LineWidth',1) % for FA
%errorbar(1:numel(con_idx),meanconFA(con_idx),stdconFA(con_idx),zeros(1,numel(con_idx)),'Color',[0.3,.6,1],'LineWidth',1) % for RD etc
hold on
errorbar(1:numel(pat_idx),meanpatFA(pat_idx),stdpatFA(pat_idx),zeros(1,numel(pat_idx)),'Color',[.9,.1,.3],'LineWidth',1); % for FA
%errorbar(1:numel(pat_idx),meanpatFA(pat_idx),zeros(1,numel(pat_idx)),stdpatFA(pat_idx),'Color',[.9,.1,.3],'LineWidth',1); % for RD etc
ylim([.3 .7])
%ylim([.0004 .0014])
xlim([0 26])
grid on
xlabel('Position along tract (C2-C6)','FontSize',14,'Color','k');
ylabel('mean FA','FontSize',14,'Color','k');
%ylabel('mean RD (mm^2/s)','FontSize',14,'Color','k');
legend('Control','ALS');

%-- Per-patient means over the whole profile and over consecutive 15-point
%   segments labelled C2..C6 (rows 1-15, 16-30, 31-45, 46-60, 61-end).
mean_along_tract_pat=mean(pat_FA, 1)';
fa_all=mean_along_tract_pat;
fa_C2=mean(pat_FA(1:15,:), 1)';
fa_C3=mean(pat_FA(16:30,:), 1)';
fa_C4=mean(pat_FA(31:45,:), 1)';
fa_C5=mean(pat_FA(46:60,:), 1)';
fa_C6=mean(pat_FA(61:end,:), 1)';
save('meanFAs','fa_all','fa_C6','fa_C5','fa_C4','fa_C3','fa_C2');% This data is used for the correlation analysis
dlmwrite('whole_cord_FA.txt',[con_FA';pat_FA'],'delimiter',' ')